import ast
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import yaml

# Task specifications loaded once at import time. Code below reads them as
# task_specs['phase_<n>'][task_index] -> dict with 'payoff', 'optimal' and
# 'difficulty' entries. safe_load is used because plain YAML data is all that is
# needed, and yaml.load() without an explicit Loader raises TypeError on PyYAML >= 6.
with open("task_specs", "r", encoding="utf-8") as file:
    task_specs = yaml.safe_load(file)

#UTILITIES 
def list_flatten(input_list):
    """Concatenate the sub-iterables of ``input_list`` into one flat list.

    Only one level of nesting is removed; the items of each sublist are kept in
    order, sublist after sublist.
    """
    flattened = []
    for sublist in input_list:
        flattened.extend(sublist)
    return flattened

def gini(x):
    """Return the Gini coefficient of the values in ``x``.

    Computed as half the relative mean absolute difference: the mean of
    |x_i - x_j| over all ordered pairs (self-pairs included), divided by the mean
    of ``x``, then halved. An all-zero input divides zero by zero, so numpy
    yields nan.

    Reference:
    https://stackoverflow.com/questions/39512260/calculating-gini-coefficient-in-python-numpy
    """
    pairwise_gaps = np.abs(np.subtract.outer(x, x))
    relative_mad = pairwise_gaps.mean() / np.mean(x)
    return relative_mad / 2

#DATA PROCESSING
def process_intermediate_solution(intermediate_solution, payoff, violation_penalty=100):
    """Score one intermediate solution of a round.

    Parameters
    ----------
    intermediate_solution : dict
        Must provide 'solution' (mapping room key -> iterable of students),
        'nConstraintsViolated', 'completeSolution' and 'at'. Room keys other
        than "deck" must be convertible with ``int()``; the "deck" entry is
        not scored (presumably the pool of unassigned students).
    payoff : mapping
        ``payoff[student][room_index]`` is the points earned for placing
        ``student`` in that room.
    violation_penalty : int or float, default 100
        Points subtracted per violated constraint to get the real score.

    Returns
    -------
    tuple
        ``(nominal_score, real_score, num_violations, is_complete,
        is_feasible, timestamp)``. ``is_feasible`` is True when no constraint
        is violated.
    """
    num_violations = intermediate_solution['nConstraintsViolated']

    nominal_score = sum(
        payoff[student][int(room)]
        for room, students in intermediate_solution['solution'].items()
        if room != "deck"
        for student in students
    )
    real_score = nominal_score - violation_penalty * num_violations

    return (
        nominal_score,
        real_score,
        num_violations,
        intermediate_solution['completeSolution'],
        num_violations == 0,
        intermediate_solution['at'],
    )

def generate_round_data(row, phase):
    """Build the per-move score table for one round.

    Parameters
    ----------
    row : object
        Needs attributes ``task_index``, ``parsed_intermediateSolutions`` and
        ``log``. Presumably a row of the rounds DataFrame; confirm against the
        caller. ``parsed_intermediateSolutions`` is an iterable of
        intermediate-solution dicts (scored by
        ``process_intermediate_solution``). ``log`` is a YAML/JSON string
        holding a list of entries with 'verb' and 'subjectId' keys.
    phase : str or int
        Phase must be '1' or '2'. Selects ``task_specs['phase_<phase>']``.

    Returns
    -------
    pandas.DataFrame
        One row per intermediate solution, with columns nominal_score,
        real_score, num_violations, is_complete, is_feasible, timestamp and
        subject_id. ``subject_id`` is the taker of each "movedStudent" log
        event. It is None for every row when the number of such events does
        not match the number of solutions.
    """
    round_payoff_dict = task_specs['phase_{}'.format(phase)][row.task_index]['payoff']
    records = [process_intermediate_solution(x, round_payoff_dict) for x in row.parsed_intermediateSolutions]

    # Build the frame straight from the tuples so each column keeps its own type.
    # Packing them into np.array first coerced every field to one common dtype
    # (strings), which then had to be re-parsed. It also failed outright for
    # rounds with no intermediate solutions.
    df_round_data = pd.DataFrame.from_records(
        records,
        columns=["nominal_score", "real_score", "num_violations", "is_complete", "is_feasible", "timestamp"],
    )

    #Data types
    for int_column in ('nominal_score', 'real_score', 'num_violations'):
        df_round_data[int_column] = df_round_data[int_column].astype(int)
    df_round_data['timestamp'] = pd.to_datetime(df_round_data['timestamp'])

    # safe_load: the log is plain lists/dicts, and yaml.load() without an
    # explicit Loader raises TypeError on PyYAML >= 6.
    action_takers = np.array([entry['subjectId'] for entry in yaml.safe_load(row.log) if entry['verb'] == "movedStudent"])

    # Only attribute moves to players when each solution has exactly one move event.
    if len(action_takers) == len(df_round_data):
        df_round_data['subject_id'] = action_takers
    else:
        df_round_data['subject_id'] = None

    return df_round_data

def phase1_round_features(dev_solution_data, dev_log, dev_task_index, task_specs, phase):
    """Compute the reduced round-level feature set used for phase-1 rounds.

    Side effect: adds the columns 'delta_violations', 'delta_nominal',
    'delta_real', 'nominal_gained_per_violation' and
    'real_gained_per_violation' to ``dev_solution_data`` in place.

    Parameters
    ----------
    dev_solution_data : pandas.DataFrame
        One row per intermediate solution, with at least 'num_violations',
        'nominal_score', 'real_score' and 'is_complete' columns.
    dev_log, dev_task_index, task_specs, phase
        Unused here; accepted so the signature matches ``round_features``.

    Returns
    -------
    dict
        ROUNDFEAT_SOLNS_num_inter_soln: number of intermediate solutions.
        ROUNDFEAT_SOLNS_index_first_complete: index label of the first
        complete solution, or None if no solution is complete.
    """
    # Data prep. fillna seeds the first row with its absolute value (there is
    # nothing to diff against). Unlike `.loc[0, ...]`, this does not require the
    # index to contain the label 0.
    dev_solution_data['delta_violations'] = dev_solution_data['num_violations'].diff().fillna(dev_solution_data['num_violations'])
    dev_solution_data['delta_nominal'] = dev_solution_data['nominal_score'].diff().fillna(dev_solution_data['nominal_score'])
    dev_solution_data['delta_real'] = dev_solution_data['real_score'].diff().fillna(dev_solution_data['real_score'])
    # Moves that leave the violation count unchanged divide by zero here (inf/NaN).
    dev_solution_data['nominal_gained_per_violation'] = dev_solution_data['delta_nominal'] / dev_solution_data['delta_violations']
    dev_solution_data['real_gained_per_violation'] = dev_solution_data['delta_real'] / dev_solution_data['delta_violations']

    #Basic features
    num_inter_soln = len(dev_solution_data)

    # May be None: round BbPxDKJjA52xHhYx2, for example, never reached a complete solution.
    complete_index = dev_solution_data[dev_solution_data['is_complete'] == 1].index
    index_first_complete = complete_index[0] if len(complete_index) > 0 else None

    # The richer feature set (scores, rules, teamwork) lives in round_features.
    round_summary = {"ROUNDFEAT_SOLNS_num_inter_soln": num_inter_soln,
                     "ROUNDFEAT_SOLNS_index_first_complete": index_first_complete,
                     }

    return round_summary

def _consolidate_actions(subject_ids):
    """Run-length encode consecutive move takers.

    Returns a list of ``(subject_id, run_length)`` tuples, one per maximal run
    of equal consecutive entries, in order. An empty sequence gives ``[]``.
    """
    return [(subject_id, sum(1 for _ in run)) for subject_id, run in itertools.groupby(subject_ids)]


def round_features(dev_solution_data, dev_log, dev_task_index, task_specs, phase):
    """Compute the full round-level feature dictionary for one round.

    Side effect: adds the columns 'delta_violations', 'delta_nominal',
    'delta_real', 'nominal_gained_per_violation' and
    'real_gained_per_violation' to ``dev_solution_data`` in place.

    Parameters
    ----------
    dev_solution_data : pandas.DataFrame
        One non-empty row per intermediate solution, in move order. Needs the
        columns 'num_violations', 'nominal_score', 'real_score', 'is_complete'
        and 'subject_id'. Several features mix index labels with positions
        (``iloc`` slicing, ``num_inter_soln - index_first_complete``), so a
        default RangeIndex is assumed.
    dev_log : iterable of dict
        Round log entries. Entries whose 'verb' is "playerSatisfaction" are
        counted, and their 'subjectId' is read.
    dev_task_index, task_specs, phase
        Locate the task's optimal score at
        ``task_specs['phase_<phase>'][dev_task_index]['optimal']``.

    Returns
    -------
    dict
        Feature name -> value, with names prefixed ``ROUNDFEAT_<GROUP>_``.
        Features that depend on a complete solution are None when the round
        never reached one.
    """
    # Data prep. fillna seeds the first row with its absolute value (there is
    # nothing to diff against) without requiring an index label 0.
    dev_solution_data['delta_violations'] = dev_solution_data['num_violations'].diff().fillna(dev_solution_data['num_violations'])
    dev_solution_data['delta_nominal'] = dev_solution_data['nominal_score'].diff().fillna(dev_solution_data['nominal_score'])
    dev_solution_data['delta_real'] = dev_solution_data['real_score'].diff().fillna(dev_solution_data['real_score'])
    # Moves that leave the violation count unchanged divide by zero here (inf/NaN).
    dev_solution_data['nominal_gained_per_violation'] = dev_solution_data['delta_nominal'] / dev_solution_data['delta_violations']
    dev_solution_data['real_gained_per_violation'] = dev_solution_data['delta_real'] / dev_solution_data['delta_violations']

    optimal_score = task_specs['phase_{}'.format(phase)][dev_task_index]['optimal']

    #Basic features
    num_inter_soln = len(dev_solution_data)

    # May be empty: round BbPxDKJjA52xHhYx2, for example, never reached a complete solution.
    complete_rows = dev_solution_data[dev_solution_data['is_complete'] == 1]
    index_first_complete = complete_rows.index[0] if len(complete_rows) > 0 else None

    bool_greedy_exceed_max = dev_solution_data['nominal_score'].max() > optimal_score
    num_submitted_violations = dev_solution_data['num_violations'].values[-1]
    highest_real_score = dev_solution_data['real_score'].max()
    submitted_real_score = dev_solution_data['real_score'].values[-1]
    # Fraction of intermediate solutions scoring no better than the submitted one.
    cdf_submitted_score = np.mean(dev_solution_data['real_score'] <= submitted_real_score)
    first_move_taker = dev_solution_data['subject_id'].values[0]
    last_move_taker = dev_solution_data['subject_id'].values[-1]
    max_solution_violations = dev_solution_data['num_violations'].max()
    avg_solution_violations = dev_solution_data['num_violations'].mean()
    # The submission is compared against the non-zero violation counts. When
    # every solution was feasible, the submission (0 violations) is trivially the
    # lowest; previously min() of the empty selection was NaN and gave False.
    positive_violations = dev_solution_data.loc[dev_solution_data['num_violations'] > 0, 'num_violations']
    bool_submitted_has_lowest_violations = positive_violations.empty or positive_violations.min() >= num_submitted_violations

    #Constraints added/removed by player
    player_constraints_added = dict(dev_solution_data.query("delta_violations > 0").groupby("subject_id")['delta_violations'].sum())
    player_constraints_removed = dict(dev_solution_data.query("delta_violations < 0").groupby("subject_id")['delta_violations'].sum().abs())

    # Longest consecutive chain by each player, and in the round
    consolidated_actions = _consolidate_actions(dev_solution_data['subject_id'].values)

    #Checksum: the runs must cover every move exactly once
    assert sum(num_moves for _, num_moves in consolidated_actions) == num_inter_soln

    consolidated_actions = pd.DataFrame(columns=["subject_id", "num_moves"], data=consolidated_actions)

    #Action streak analysis
    player_longest_streak = dict(consolidated_actions.groupby("subject_id")['num_moves'].max())
    longest_streak = consolidated_actions['num_moves'].max()
    game_longest_streak = (list(consolidated_actions[consolidated_actions['num_moves'] == longest_streak].drop_duplicates().subject_id), longest_streak)

    #Action counts
    if index_first_complete is None:
        pre_complete_action_counts = None
        post_complete_action_counts = None
        prepost_complete_player_count_diff = None
        first_complete_real_score = None
        first_complete_greedy_score = None
        real_score_pace_to_complete = None
        real_score_pace_complete_to_submit = None
        greedy_score_pace_to_complete = None
        greedy_score_pace_complete_to_submit = None
    else:
        pre_complete_action_counts = dict(dev_solution_data.iloc[:index_first_complete + 1].subject_id.value_counts())
        post_complete_action_counts = dict(dev_solution_data.iloc[index_first_complete + 1:].subject_id.value_counts())
        prepost_complete_player_count_diff = len(post_complete_action_counts) - len(pre_complete_action_counts)

        first_complete_real_score = complete_rows.real_score.values[0]
        first_complete_greedy_score = complete_rows.nominal_score.values[0]

        # NOTE(review): index_first_complete is 0 when the very first solution is
        # already complete, so the *_pace_to_complete ratios divide by zero.
        real_score_pace_to_complete = first_complete_real_score / index_first_complete
        real_score_pace_complete_to_submit = (submitted_real_score - first_complete_real_score) / (num_inter_soln - index_first_complete)

        greedy_score_pace_to_complete = first_complete_greedy_score / index_first_complete
        greedy_score_pace_complete_to_submit = (dev_solution_data.nominal_score.values[-1] - first_complete_greedy_score) / (num_inter_soln - index_first_complete)

    action_counts = dev_solution_data['subject_id'].value_counts()
    player_action_counts = dict(action_counts)
    most_active_player = (list(action_counts[action_counts == action_counts.max()].index), action_counts.max())
    least_active_player = (list(action_counts[action_counts == action_counts.min()].index), action_counts.min())

    # Satisfaction measures
    satisfaction_events = [entry for entry in dev_log if entry['verb'] == "playerSatisfaction"]
    num_satisfaction_flags = len(satisfaction_events)
    first_satisfied_player = satisfaction_events[0]['subjectId'] if satisfaction_events else None
    last_satisfied_player = satisfaction_events[-1]['subjectId'] if satisfaction_events else None

    #Suggested features
    num_complete_soln = dev_solution_data.is_complete.sum()
    num_feasible_soln = len(dev_solution_data.query("num_violations == 0"))
    num_feasible_and_complete = len(dev_solution_data.query("num_violations == 0 and is_complete == 1"))
    num_greedy_exceed_max = np.sum(dev_solution_data['nominal_score'] > optimal_score)
    bool_optimal_found = any(dev_solution_data['real_score'] == optimal_score)
    bool_submitted_highest_found = cdf_submitted_score == 1
    bool_submitted_highest_greedy = all(dev_solution_data['nominal_score'].values[-1] >= dev_solution_data['nominal_score'].values)
    bool_submitted_is_complete = dev_solution_data['is_complete'].values[-1] == 1
    bool_submitted_is_feasible = num_submitted_violations == 0
    bool_submitted_feasible_complete = bool_submitted_is_complete and bool_submitted_is_feasible

    highest_complete_score = complete_rows['real_score'].max()
    bool_submitted_highest_complete_score = highest_complete_score == submitted_real_score

    bool_optimal_found_and_submitted = bool_optimal_found and bool_submitted_highest_found
    bool_optimal_found_not_submitted = bool_optimal_found and not bool_submitted_highest_found

    # Pad to three entries so players who never moved count as zero
    # (assumes three-player groups).
    gini_moves = gini(list(player_action_counts.values()) + [0] * (3 - len(player_action_counts)))

    avg_real_gained_per_violation = dev_solution_data.query("delta_violations > 0")['real_gained_per_violation'].mean()
    avg_real_gained_per_violation_incomplete = dev_solution_data.query("delta_violations > 0 and is_complete == 0")['real_gained_per_violation'].mean()
    avg_real_gained_per_violation_complete = dev_solution_data.query("delta_violations > 0 and is_complete == 1")['real_gained_per_violation'].mean()

    round_summary = {"ROUNDFEAT_SOLNS_num_inter_soln":num_inter_soln,
                     "ROUNDFEAT_SOLNS_index_first_complete":index_first_complete,
                     "ROUNDFEAT_SCORES_bool_greedy_exceed_max":bool_greedy_exceed_max,
                     "ROUNDFEAT_RULES_num_submitted_violations":num_submitted_violations,
                     "ROUNDFEAT_SCORES_highest_real_score":highest_real_score,
                     "ROUNDFEAT_SCORES_cdf_submitted_score":cdf_submitted_score,
                     "ROUNDFEAT_TMWRK_first_move_taker":first_move_taker,
                     "ROUNDFEAT_TMWRK_last_move_taker":last_move_taker,
                     "ROUNDFEAT_RULES_max_solution_violations":max_solution_violations,
                     "ROUNDFEAT_RULES_avg_solution_violations":avg_solution_violations,
                     "ROUNDFEAT_RULES_bool_submitted_has_lowest_violations":bool_submitted_has_lowest_violations,
                     "ROUNDFEAT_RULES_player_constraints_added":player_constraints_added,
                     "ROUNDFEAT_RULES_player_constraints_removed":player_constraints_removed,
                     "ROUNDFEAT_TMWRK_consolidated_actions":consolidated_actions,
                     "ROUNDFEAT_TMWRK_player_longest_streak":player_longest_streak,
                     "ROUNDFEAT_TMWRK_game_longest_streak":game_longest_streak,
                     "ROUNDFEAT_TMWRK_pre_complete_action_counts":pre_complete_action_counts,
                     "ROUNDFEAT_TMWRK_post_complete_action_counts":post_complete_action_counts,
                     "ROUNDFEAT_TMWRK_player_action_counts":player_action_counts,
                     "ROUNDFEAT_TMWRK_prepost_complete_player_count_diff":prepost_complete_player_count_diff,
                     "ROUNDFEAT_TMWRK_most_active_player":most_active_player,
                     "ROUNDFEAT_TMWRK_least_active_player":least_active_player,
                     "ROUNDFEAT_TMWRK_num_satisfaction_flags":num_satisfaction_flags,
                     "ROUNDFEAT_TMWRK_first_satisfied_player":first_satisfied_player,
                     "ROUNDFEAT_TMWRK_last_satisfied_player":last_satisfied_player,
                     "ROUNDFEAT_SOLNS_num_complete_soln":num_complete_soln,
                     "ROUNDFEAT_SOLNS_num_feasible_soln":num_feasible_soln,
                     "ROUNDFEAT_SOLNS_num_feasible_and_complete":num_feasible_and_complete,
                     "ROUNDFEAT_SCORES_num_greedy_exceed_max":num_greedy_exceed_max,
                     "ROUNDFEAT_SCORES_bool_optimal_found":bool_optimal_found,
                     "ROUNDFEAT_SCORES_bool_submitted_highest_found":bool_submitted_highest_found,
                     "ROUNDFEAT_SCORES_bool_submitted_highest_greedy":bool_submitted_highest_greedy,
                     "ROUNDFEAT_SOLNS_bool_submitted_is_complete":bool_submitted_is_complete,
                     "ROUNDFEAT_SOLNS_bool_submitted_is_feasible":bool_submitted_is_feasible,
                     "ROUNDFEAT_SOLNS_bool_submitted_feasible_complete":bool_submitted_feasible_complete,
                     "ROUNDFEAT_SCORES_highest_complete_score":highest_complete_score,
                     "ROUNDFEAT_SCORES_bool_submitted_highest_complete_score":bool_submitted_highest_complete_score,
                     "ROUNDFEAT_SCORES_bool_optimal_found_and_submitted":bool_optimal_found_and_submitted,
                     "ROUNDFEAT_SCORES_bool_optimal_found_not_submitted":bool_optimal_found_not_submitted,
                     "ROUNDFEAT_TMWRK_gini_moves":gini_moves,
                     "ROUNDFEAT_SCORES_first_complete_real_score":first_complete_real_score,
                     "ROUNDFEAT_SCORES_first_complete_greedy_score":first_complete_greedy_score,
                     "ROUNDFEAT_SCORES_real_score_pace_to_complete":real_score_pace_to_complete,
                     "ROUNDFEAT_SCORES_real_score_pace_complete_to_submit":real_score_pace_complete_to_submit,
                     "ROUNDFEAT_SCORES_greedy_score_pace_to_complete":greedy_score_pace_to_complete,
                     "ROUNDFEAT_SCORES_greedy_score_pace_complete_to_submit":greedy_score_pace_complete_to_submit,
                     "ROUNDFEAT_RULES_avg_real_gained_per_violation":avg_real_gained_per_violation,
                     "ROUNDFEAT_RULES_avg_real_gained_per_violation_incomplete":avg_real_gained_per_violation_incomplete,
                     "ROUNDFEAT_RULES_avg_real_gained_per_violation_complete":avg_real_gained_per_violation_complete}

    return round_summary

#VISUALIZATION
def detailed_plots(df, col):
    """Plot ``col`` against the experimental factors in ``df``.

    First draws a point plot of ``col`` by 'diff_cat', split by
    'group_formation', on the current axes and titles it with ``col``. Then it
    draws seven faceted point-plot ``catplot`` figures crossing 'diff_cat' or
    'round_index' with the 'group_formation', 'block_cat', 'skill_cat' and
    'social_cat' columns. Column names of 43 or more characters get an explicit
    axes label size of 10.
    """
    sns.pointplot(data=df, x="diff_cat", y=col, hue="group_formation", scale=2)
    plt.title(col)

    # (x, hue, facet column) for each catplot, in drawing order.
    facet_specs = [
        ("diff_cat", "group_formation", "block_cat"),
        ("diff_cat", "group_formation", "skill_cat"),
        ("diff_cat", "group_formation", "social_cat"),
        ("diff_cat", "skill_cat", "group_formation"),
        ("diff_cat", "social_cat", "group_formation"),
        ("round_index", "skill_cat", "group_formation"),
        ("round_index", "social_cat", "group_formation"),
    ]

    # rc=None is plotting_context's default, i.e. no label-size override.
    label_rc = {'axes.labelsize': 10} if len(col) >= 43 else None
    with sns.plotting_context("notebook", font_scale=1.5, rc=label_rc):
        for x_var, hue_var, facet_var in facet_specs:
            sns.catplot(kind="point", x=x_var, y=col, hue=hue_var, col=facet_var, data=df, scale=2)

def plot_round(round_data, task_spec_dict, round_id):
    """Plot one round's score trajectory, normalised by the task's optimal score.

    Parameters
    ----------
    round_data : pandas.DataFrame
        Per-move table with 'nominal_score', 'real_score', 'is_complete' and
        'subject_id' columns, as built by ``generate_round_data``. The x-axis
        uses both positions and index labels, so a default RangeIndex is
        assumed.
    task_spec_dict : dict
        Task spec providing 'optimal' (used for normalisation) and
        'difficulty' (shown in the title).
    round_id : str
        Shown in the title.

    The plot shows:
      * nominal and real score lines;
      * one scatter per move taker (colors cycle if there are more than three);
      * a dashed vertical line at the first complete solution, drawn only if
        one exists;
      * a dashed horizontal line at the optimum (y = 1).
    """
    optimal_score = task_spec_dict['optimal']
    num_inter_solns = len(round_data)
    score_columns = round_data.loc[:, ["nominal_score", "real_score"]]

    plt.figure(figsize=(15, 7))
    plt.plot(range(num_inter_solns), round_data['nominal_score'] / optimal_score, label="Nominal Score")
    plt.plot(range(num_inter_solns), round_data['real_score'] / optimal_score, label="Real Score")

    player_colors = ["red", "green", "blue"]
    for color_index, player in enumerate(round_data.subject_id.unique()):
        player_moves = round_data.query("subject_id == @player")
        plt.scatter(player_moves.index, player_moves['real_score'] / optimal_score,
                    color=player_colors[color_index % len(player_colors)], label=str(player))

    # Some rounds never reach a complete solution; skip the marker rather than raise.
    complete_index = round_data[round_data['is_complete'] == 1].index
    if len(complete_index) > 0:
        plt.vlines(x=complete_index[0],
                   ymin=score_columns.min().min() / optimal_score,
                   ymax=score_columns.max().max() / optimal_score,
                   linestyle="--", label="First Complete Solution")

    # Optimum reference line; left unlabelled so it stays out of the legend.
    plt.hlines(xmin=0, xmax=len(round_data), y=1, linestyle="--")

    plt.legend()
    plt.xlabel("Solution Sequence")
    plt.ylabel("Score")
    plt.title("Round ID: {} -- Difficulty: {}".format(round_id, task_spec_dict['difficulty']))
    plt.show()